import gym
import d4rl_atari
import numpy as np
import pickle
from pathlib import Path
from collections import deque

# Atari games whose expert datasets are exported; each game is listed once so
# no dataset is processed (and its output overwritten) twice.
games = ['asterix', 'breakout', 'space-invaders', 'seaquest', 'pong']

# Maximum number of trajectories stored per game.
MAX_DEMOS = 50


def split_trajectories(terminals):
    """Return half-open ``(start, end)`` index pairs of complete episodes.

    An episode ends on the step whose terminal flag equals 1, so the next
    episode starts one step later. The steps before the first terminal flag
    and after the last one are skipped because their episode boundaries are
    not both visible in the data.

    Args:
        terminals: 1-D sequence of terminal flags (1 marks an episode end).

    Returns:
        List of ``(start, end)`` tuples such that ``data[start:end]`` covers
        one episode, including its terminal step.
    """
    ends = np.flatnonzero(np.asarray(terminals) == 1)
    return [
        (int(prev_end) + 1, int(cur_end) + 1)
        for prev_end, cur_end in zip(ends[:-1], ends[1:])
    ]


def to_object_array(items):
    """Pack ``items`` into a 1-D object array with one element per item.

    Trajectories generally differ in length; calling ``np.array`` on such a
    ragged list raises ``ValueError`` on recent NumPy releases, so each slot
    is filled explicitly instead.
    """
    packed = np.empty(len(items), dtype=object)
    for index, item in enumerate(items):
        packed[index] = item
    return packed


def collect_demos(dataset, max_demos=MAX_DEMOS):
    """Cut ``dataset`` into per-episode arrays, keeping the first episodes.

    Args:
        dataset: Mapping with 'observations', 'actions', 'rewards' and
            'terminals' entries that are aligned step by step.
        max_demos: Maximum number of episodes to keep.

    Returns:
        ``[observations, terminals, actions, rewards]``, each a 1-D object
        array holding one array per episode. Rewards get a trailing axis of
        length 1.
    """
    # Truncate the span list first so only the kept episodes are copied.
    spans = split_trajectories(dataset['terminals'])[:max_demos]

    observations = [np.array(dataset['observations'][lo:hi]) for lo, hi in spans]
    terminals = [np.array(dataset['terminals'][lo:hi]) for lo, hi in spans]
    actions = [np.array(dataset['actions'][lo:hi]) for lo, hi in spans]
    rewards = [
        np.asarray(dataset['rewards'][lo:hi])[:, np.newaxis] for lo, hi in spans
    ]

    return [
        to_object_array(observations),
        to_object_array(terminals),
        to_object_array(actions),
        to_object_array(rewards),
    ]


def main():
    """Export expert demonstrations of every game in ``games`` to pickle files."""
    for game in games:
        print("Processing game: ", game)
        # -v{0, 1, 2, 3, 4} for datasets with the other random seeds
        env = gym.make('{}-expert-v0'.format(game), stack=True)

        # dataset will be automatically downloaded into
        # ~/.d4rl/datasets/[GAME]/[INDEX]/[EPOCH]
        dataset = env.get_dataset()

        payload = collect_demos(dataset)
        rewards_list = payload[3]
        if len(rewards_list) == 0:
            # Nothing to summarise or save without one complete episode.
            print("No complete trajectories found for game: ", game)
            continue

        # Mean and standard deviation of the per-episode returns.
        returns = [np.sum(episode_rewards) for episode_rewards in rewards_list]
        print(np.mean(returns), np.std(returns))

        # Save demo in pickle file
        save_dir = Path("expert_demos/atari/{}".format(game))
        save_dir.mkdir(parents=True, exist_ok=True)
        snapshot_path = save_dir / 'expert_demos.pkl'

        with open(str(snapshot_path), 'wb') as f:
            pickle.dump(payload, f)


if __name__ == "__main__":
    main()